In [607]:
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [608]:
from sklearn.datasets import make_moons
from sklearn.preprocessing import normalize
from matplotlib import pyplot as plt
import matplotlib.cm as cm
import numpy as np
import torch
import matplotlib
import itertools
In [609]:
FIG_SIZE = 4  # side length (inches) of one square subplot
SEED = 0  # seeds numpy's and torch's global RNGs so every figure is reproducible

matplotlib.rcParams['figure.figsize'] = (FIG_SIZE, FIG_SIZE)
torch.manual_seed(SEED)
np.random.seed(SEED)  # last statement returns None, so the cell displays nothing

Tools

In [610]:
def scatter(X, title=None, ax=None, show_axes=False):
    """Draw the 2-D points of ``X`` as a scatter plot of small markers.

    Args:
        X: array-like indexable as ``X[:, 0]`` / ``X[:, 1]``; those two columns are plotted.
        title: optional axes title.
        ax: target axes; the current pyplot axes is used when omitted.
        show_axes: keep the tick labels when True, blank them out otherwise.
    """
    target = plt.gca() if ax is None else ax
    target.scatter(X[:, 0], X[:, 1], s=10)
    if title is not None:
        target.set_title(title)
    if show_axes:
        return
    for clear_ticks in (target.set_xticklabels, target.set_yticklabels):
        clear_ticks([])

def plot(X, y, title=None, ax=None, show_axes=False, label=None):
    """Draw ``y`` against ``X`` as a line on the given (or current) axes.

    Args:
        X, y: data forwarded unchanged to ``ax.plot``.
        title: optional axes title.
        ax: target axes; the current pyplot axes is used when omitted.
        show_axes: keep the tick labels when True, blank them out otherwise.
        label: legend label forwarded to ``ax.plot``.
    """
    target = ax if ax is not None else plt.gca()
    target.plot(X, y, label=label)
    if title is not None:
        target.set_title(title)
    if not show_axes:
        for clear_ticks in (target.set_xticklabels, target.set_yticklabels):
            clear_ticks([])

Toy dataset

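$X \in \mathbb{R}^{N \times d}$: $N$ samples in $d = 2$ dimensions, standardized to zero mean and unit variance per coordinate.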

In [611]:
N_SAMPLES = 1000

# Moons toy dataset: two interleaving half-circles; the class labels are discarded.
dataset = make_moons(n_samples=N_SAMPLES)
X, _ = dataset
X = torch.as_tensor(X, dtype=torch.float32)

# Alternative: circle toy dataset
# theta = torch.linspace(0, 2*torch.pi, 1000).unsqueeze(1)
# X = torch.cat([torch.cos(theta), torch.sin(theta)], dim=1)

# Standardize each coordinate independently (per-column mean and standard deviation).
X = (X - X.mean(dim=0)) / X.std(dim=0)

scatter(X, title="Toy dataset", show_axes=True)

Understanding the model formulation

A DDPM (Denoising Diffusion Probabilistic Model) gradually adds noise to the data through the following forward process:

$$ x_{t} = \sqrt{1 - \beta_t}x_{t-1} + \sqrt{\beta_t} \epsilon_{t} $$

with $\epsilon_{t} \sim \mathcal{N}(0, 1)$ and $\beta_t \in [0, 1]$.

A more general (and more intuitive) formulation is:

$$ x_{t} = (1 - \beta_t)^px_{t-1} + {\beta_t}^p \epsilon_{t} $$

with $p \in \mathbb{R}_+$.

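Setting $p = \frac{1}{2}$ allows us to compute $x_t$ directly from $x_0$. Write $\alpha_t = 1 - \beta_t$ and $\overline{\alpha_t} = \prod_{s=1}^{t} \alpha_s$. The two noise terms below merge into one because the sum of independent $\mathcal{N}(0, \sigma_1^2)$ and $\mathcal{N}(0, \sigma_2^2)$ variables is $\mathcal{N}(0, \sigma_1^2 + \sigma_2^2)$: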

$$ \begin{align} x_t &= \sqrt{1 - \beta_t}x_{t-1} + \sqrt{\beta_t} \epsilon_{t} \\ &= \sqrt{\alpha_t}x_{t-1} + \sqrt{1 - \alpha_t} \epsilon_{t} \\ &= \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1 - \alpha_{t-1}} \epsilon_{t-1} \right) + \sqrt{1 - \alpha_t} \epsilon_{t} \\ &= \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{\alpha_t}\sqrt{1 - \alpha_{t-1}} \epsilon_{t-1} + \sqrt{1 - \alpha_t} \epsilon_{t} \\ &= \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{\alpha_t(1 - \alpha_{t-1}) + 1 - \alpha_t} \epsilon_{t, \, t-1} \\ &= \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1 - \alpha_t\alpha_{t-1}} \epsilon_{t, \, t-1} \\ &= \dots \\ &= \sqrt{\overline{\alpha_t}}x_{0} + \sqrt{1 - \overline{\alpha_t}} \epsilon \end{align} $$
  1. Linear scaling on $x_0$ with respect to $\overline{\alpha_t'}$ by setting $\overline{\alpha_t'} = \sqrt{\overline{\alpha_t}} \implies \overline{\alpha_t} = \overline{\alpha_t'}^2$:
$$ x_{t} = \overline{\alpha_t'} x_{0} + \sqrt{1 - \overline{\alpha_t'}^2} \epsilon $$
  2. Linear scaling on $\epsilon$ with respect to $\overline{\alpha_t'}$ by setting $1 - \overline{\alpha_t'} = \sqrt{1 - \overline{\alpha_t}} \implies \overline{\alpha_t} = 1 - (1 - \overline{\alpha_t'})^2$:
$$ x_{t} = \sqrt{1 - (1 - \overline{\alpha_t'})^2}x_{0} + (1 - \overline{\alpha_t'}) \epsilon $$
  3. Sqrt interpolation (model):
$$ x_{t} = \sqrt{\overline{\alpha_t}}x_{0} + \sqrt{1 - \overline{\alpha_t}} \epsilon $$
  4. Linear convex interpolation:
$$ x_{t} = \overline{\alpha_t}x_{0} + (1 - \overline{\alpha_t}) \epsilon $$

Below, we plot each combination's coefficients on $x_0$ and $\epsilon$ for a uniformly spaced sequence of $\overline{\alpha_t}$ going from 1 to 0.

In [612]:
# Each entry maps alpha_bar to (label, coefficient on x_0, coefficient on epsilon).
# Backslashes are doubled: "\e" is an invalid escape sequence (SyntaxWarning on 3.12+).
interpolation_list = [
    lambda alpha: ("Linear scaling on $x_0$ (1)", alpha, (1 - alpha**2)**0.5),
    lambda alpha: ("Linear scaling on $\\epsilon$ (2)", (1 - (1 - alpha)**2)**0.5, 1 - alpha),
    lambda alpha: ("Sqrt combination [model] (3)", alpha**0.5, (1 - alpha)**0.5),
    lambda alpha: ("Linear convex combination (4)", alpha, 1 - alpha),
]
In [613]:
def plot_interpolations(interpolation_list, suptitle=None):
    """Plot, one subplot per entry, how each interpolation weights x_0 and epsilon.

    Each entry of ``interpolation_list`` is called with an array of alpha-bar values
    (100 points from 1 down to 0) and must return ``(title, x_coef, eps_coef)``.
    Each subplot shows both coefficients and shades the relative noise share
    ``eps_coef / (x_coef + eps_coef)``.

    Args:
        interpolation_list: sequence of callables as described above.
        suptitle: optional figure-level title.
    """
    n = len(interpolation_list)
    alpha = np.linspace(1, 0, 100)
    # squeeze=False keeps `axes` 2-D even when n == 1, so the zip below still works.
    fig, axes = plt.subplots(1, n, figsize=(FIG_SIZE * n, FIG_SIZE), constrained_layout=True, squeeze=False)

    if suptitle is not None:
        fig.suptitle(suptitle)

    for i, (interpolation_fct, ax) in enumerate(zip(interpolation_list, axes[0])):
        title, x_coef, eps_coef = interpolation_fct(alpha)
        ax.set_title(title)
        ax.set_yticks([0, 0.5, 1], ["$x_0$", "$\\frac{x_0 + \\epsilon}{2}$", "$\\epsilon$"], fontsize=14)
        ax.plot(alpha, [0.5] * len(alpha), "--", color="black", alpha=0.3)
        # Only the first subplot carries labels, so fig.legend() lists each curve once.
        ax.plot(alpha, x_coef, label="$x_0$" if i == 0 else None)
        ax.plot(alpha, eps_coef, label="$\\epsilon$" if i == 0 else None)
        ax.fill_between(alpha, eps_coef / (x_coef + eps_coef), alpha=0.1, color="black", label="$x_t$" if i == 0 else None)
        ax.set_xlabel("$\\overline{\\alpha_t}$")
        ax.invert_xaxis()

    fig.legend(loc="center right", bbox_to_anchor=(1.1, 0.5), fontsize=14)


plot_interpolations(interpolation_list, suptitle="Evolution of $x_t$ with linear scheduling of $\\overline{\\alpha_t}$")
In [614]:
def compare_interpolations(interpolation_list, suptitle=None):
    """Overlay the relative noise share of several interpolations on the current axes.

    Each callable in ``interpolation_list`` receives 100 alpha-bar values in [0, 1]
    and returns ``(title, x_coef, eps_coef)``; the curve drawn is
    ``eps_coef / (x_coef + eps_coef)`` labelled with ``title``.

    Args:
        interpolation_list: sequence of callables as described above.
        suptitle: optional figure-level title.
    """
    alpha = np.linspace(0, 1, 100)

    if suptitle is not None:
        plt.suptitle(suptitle)

    plt.axhline(0.5, linestyle="--", color="black", alpha=0.3)
    for interpolation_fct in interpolation_list:
        title, x_coef, eps_coef = interpolation_fct(alpha)
        plt.plot(alpha, eps_coef / (x_coef + eps_coef), label=title)
    # Backslashes are doubled: "\e" and "\o" are invalid escape sequences.
    plt.yticks([0, 0.5, 1], ["$x_0$", "$\\frac{x_0 + \\epsilon}{2}$", "$\\epsilon$"])
    plt.legend(loc="upper right", bbox_to_anchor=(1.1, 1))
    plt.xlabel("$\\overline{\\alpha_t}$")
    plt.gca().invert_xaxis()


compare_interpolations(interpolation_list, suptitle="Comparison of $x_t$ with different linear scheduling of $\\overline{\\alpha_t}$")
In [615]:
def plot_transformations(interpolation_list, suptitle=None, include_custom_alpha_bar_list=False, n_steps=10):
    """Show the noised toy dataset x_t for several interpolations and noise levels.

    Draws a grid with one row per interpolation and ``n_steps`` columns. Every panel
    scatters ``x_coef * X + eps_coef * eps``, where ``X`` is the global toy dataset
    and ``eps`` is a single standard-normal draw shared by all panels.

    Args:
        interpolation_list: callables ``alpha_bar -> (title, x_coef, eps_coef)``, or,
            when ``include_custom_alpha_bar_list`` is True, tuples
            ``(alpha_bar_list, callable)``.
        suptitle: optional figure-level title.
        include_custom_alpha_bar_list: if True, each row uses its own
            ``alpha_bar_list`` subsampled to at most ``n_steps`` values; otherwise
            every row uses ``n_steps`` values evenly spaced from 1 to 0.
        n_steps: number of columns (noise levels) per row.
    """
    m = len(interpolation_list)
    # squeeze=False keeps `axes` 2-D, so axes[i, j] works even for a single row or column.
    fig, axes = plt.subplots(m, n_steps, figsize=(n_steps * FIG_SIZE, m * FIG_SIZE), constrained_layout=True, squeeze=False)

    if suptitle is not None:
        fig.suptitle(suptitle, fontsize=20)

    eps = torch.normal(0, 1, size=X.shape)
    for i, interpolation_fct in enumerate(interpolation_list):
        if include_custom_alpha_bar_list:
            assert isinstance(interpolation_fct, tuple), "interpolation_fct must be a tuple if include_custom_alpha_bar_list is True"
            alpha_bar_list, interpolation_fct = interpolation_fct
            # Evenly strided subsample. The stride is at least 1, and the result is
            # truncated so it never has more entries than there are columns.
            stride = max(1, len(alpha_bar_list) // n_steps)
            alpha_bar_list = alpha_bar_list[::stride][:n_steps]
        else:
            alpha_bar_list = torch.linspace(1, 0, n_steps)
        for j, alpha_bar in enumerate(alpha_bar_list):
            title, x_coef, eps_coef = interpolation_fct(alpha_bar)
            if j == 0:
                axes[i, j].set_ylabel(title, fontsize=14)
            if i == 0 or include_custom_alpha_bar_list:
                axes[i, j].set_title(f"$\\overline{{\\alpha_t}}={alpha_bar:.1f}$", fontsize=18)
            scatter(x_coef * X + eps_coef * eps, ax=axes[i, j], show_axes=True)


plot_transformations(interpolation_list, suptitle="Transformation of $x_t$ using different linear scheduling of $\\overline{\\alpha_t}$")

Remember that the sequence above is uniformly spaced in $\overline{\alpha_t}$, NOT in $\alpha_t$.

This experiment helps us understand the impact of $\overline{\alpha_t}$ on the model. With this framework, we can easily compare different scheduling strategies and their impact on how $x_t$ is transformed.

Testing different $\alpha$ schedules

In [616]:
T = 1_000  # total number of diffusion steps used by the schedules below
In [617]:
def compare_alpha_bar_evolutions(interpolation_list, suptitle=None):
    """Plot each schedule's alpha-bar sequence against the normalized step t / T.

    Args:
        interpolation_list: tuples ``(alpha_bar_list, interpolation_fct)``. The
            callable is only used to obtain the curve's legend label (its first
            return value).
        suptitle: optional figure-level title.
    """
    if suptitle is not None:
        plt.suptitle(suptitle)
    for entry in interpolation_list:
        assert isinstance(entry, tuple), "interpolation_fct must be a tuple"
        alpha_bar_list, interpolation_fct = entry
        # Build the x-axis from this schedule's own length rather than the global T,
        # so schedules of any length plot without a shape mismatch.
        steps = np.linspace(0, 1, len(alpha_bar_list))
        title, *_ = interpolation_fct(steps)
        plt.plot(steps, alpha_bar_list, label=title)
    plt.legend(loc="upper right", bbox_to_anchor=(1.1, 1))
    plt.xlabel("diffusion step ($t / T$)")
    plt.ylabel("$\\overline{\\alpha_t}$")

Fixed $\alpha$ scheduling

In [618]:
def get_fixed_alpha_bars(cst, n_steps=None):
    """Return alpha_bar_t for a constant alpha_t = ``cst`` (i.e. beta_t = 1 - cst).

    Args:
        cst: the constant alpha_t.
        n_steps: number of diffusion steps; defaults to the notebook-level ``T``.

    Returns:
        1-D float array of length ``n_steps`` whose entry ``i`` is the running
        product ``cst ** (i + 1)``.
    """
    if n_steps is None:
        n_steps = T
    return (np.ones(n_steps) * cst).cumprod(0)


# (beta_t, legend label) pairs for constant schedules; every schedule is paired with
# the sqrt ("model") combination x_t = sqrt(alpha_bar) x_0 + sqrt(1 - alpha_bar) eps.
# `label=label` binds the current label into each lambda (avoids late binding).
interpolation_list = [
    (
        get_fixed_alpha_bars(1 - beta),
        lambda alpha_bar, label=label: (label, alpha_bar**0.5, (1 - alpha_bar)**0.5),
    )
    for beta, label in (
        (1e-1, "$\\beta_t=10^{-1}$"),
        (1e-2, "$\\beta_t=10^{-2}$"),
        (1e-2/2, "$\\beta_t=\\frac{10^{-2}}{2}$"),
        (1e-2/3, "$\\beta_t=\\frac{10^{-2}}{3}$"),
        (1e-2/4, "$\\beta_t=\\frac{10^{-2}}{4}$"),
        (1e-3, "$\\beta_t=10^{-3}$"),
        (1e-4, "$\\beta_t=10^{-4}$"),
    )
]
In [619]:
plot_transformations(
    interpolation_list,
    suptitle="Transformation of $x_t$ with fixed $\\alpha$ scheduling",
    include_custom_alpha_bar_list=True,
)
In [620]:
compare_alpha_bar_evolutions(interpolation_list, suptitle="Evolution of $\\overline{\\alpha_t}$ with fixed $\\alpha$ scheduling")

Linear $\alpha$ scheduling

In [621]:
def get_linear_alpha_bars(start, end, n_steps=None):
    """Return alpha_bar_t for alpha_t linearly spaced from ``start`` to ``end``.

    Args:
        start: alpha_1 (i.e. 1 - beta_1).
        end: alpha_T (i.e. 1 - beta_T).
        n_steps: number of diffusion steps; defaults to the notebook-level ``T``.

    Returns:
        1-D float array of length ``n_steps``: the running product of the linearly
        spaced alphas.
    """
    if n_steps is None:
        n_steps = T
    return np.linspace(start, end, n_steps).cumprod(0)


# Linearly spaced alpha_t = 1 - beta_t, each paired with the sqrt ("model") combination.
# Backslashes are doubled: a bare backslash-i is an invalid escape sequence (SyntaxWarning on 3.12+).
interpolation_list = [
    (get_linear_alpha_bars(1 - 1e-4, 1 - 0.02), lambda alpha_bar: ("$\\beta_t \\in [10^{-4}, 0.02]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)), # from "Denoising Diffusion Probabilistic Models"
    (get_linear_alpha_bars(1 - 1e-5, 1 - 1e-1), lambda alpha_bar: ("$\\beta_t \\in [10^{-5}, 10^{-1}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-5, 1 - 1e-2), lambda alpha_bar: ("$\\beta_t \\in [10^{-5}, 10^{-2}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-5, 1 - 1e-3), lambda alpha_bar: ("$\\beta_t \\in [10^{-5}, 10^{-3}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-4, 1 - 1e-1), lambda alpha_bar: ("$\\beta_t \\in [10^{-4}, 10^{-1}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-4, 1 - 1e-2), lambda alpha_bar: ("$\\beta_t \\in [10^{-4}, 10^{-2}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
    (get_linear_alpha_bars(1 - 1e-4, 1 - 1e-3), lambda alpha_bar: ("$\\beta_t \\in [10^{-4}, 10^{-3}]$", alpha_bar**0.5, (1 - alpha_bar)**0.5)),
]
In [622]:
plot_transformations(
    interpolation_list,
    suptitle="Transformation of $x_t$ with linear $\\alpha$ scheduling",
    include_custom_alpha_bar_list=True,
)
In [623]:
compare_alpha_bar_evolutions(interpolation_list, suptitle="Evolution of $\\overline{\\alpha_t}$ with linear $\\alpha$ scheduling")

Cosine $\alpha$ scheduling

In [624]:
def get_cosine_alpha_bars(s, n_steps=None):
    """Return alpha_bar_t of the cosine schedule.

    alpha_bar_t = f(t) / f(0) with f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2),
    evaluated at the integer steps t = 1, ..., T. As with the other
    ``get_*_alpha_bars`` helpers, entry ``i`` is alpha_bar_{i+1}.

    Args:
        s: offset added to t / T inside the cosine.
        n_steps: number of diffusion steps T; defaults to the notebook-level ``T``.

    Returns:
        1-D float array of length ``n_steps``, decreasing from just below 1 to ~0.
    """
    if n_steps is None:
        n_steps = T

    def f(t):
        return np.cos((t / n_steps + s) / (1 + s) * np.pi / 2) ** 2

    # Dividing by f(0) makes alpha_bar_0 = 1 exactly (no noise at t = 0). The raw
    # f(0) = cos^2(s / (1 + s) * pi / 2) is below 1 and would add noise even at t = 0.
    t = np.arange(1, n_steps + 1)
    return f(t) / f(0)


# Cosine schedules for several offsets s, all paired with the sqrt ("model") combination.
# `label=label` binds the current label into each lambda (avoids late binding).
interpolation_list = [
    (
        get_cosine_alpha_bars(s),
        lambda alpha_bar, label=label: (label, alpha_bar**0.5, (1 - alpha_bar)**0.5),
    )
    for s, label in (
        (0.008, "$s = 0.008$"),
        (1e-1, "$s = 10^{-1}$"),
        (1e-2, "$s = 10^{-2}$"),
        (1e-3, "$s = 10^{-3}$"),
        (1e-4, "$s = 10^{-4}$"),
        (1e-8, "$s = 10^{-8}$"),
    )
]
In [625]:
plot_transformations(
    interpolation_list,
    suptitle="Transformation of $x_t$ with cosine $\\alpha$ scheduling",
    include_custom_alpha_bar_list=True,
)
In [626]:
compare_alpha_bar_evolutions(interpolation_list, suptitle="Evolution of $\\overline{\\alpha_t}$ with cosine $\\alpha$ scheduling")

Comparing the best $\alpha$ schedules

In [627]:
# Best schedule of each family, paired with the sqrt ("model") combination (3).
# `label=label` binds the current label into each lambda (avoids late binding).
selected_interpolation_list = [
    (alpha_bars, lambda alpha_bar, label=label: (label, alpha_bar**0.5, (1 - alpha_bar)**0.5))
    for label, alpha_bars in (
        ("fixed", get_fixed_alpha_bars(1 - 1e-2/4)),
        ("linear", get_linear_alpha_bars(1 - 1e-5, 1 - 1e-2)),
        ("cosine", get_cosine_alpha_bars(1e-2)),
    )
]

# The same schedules, paired with "linear scaling on epsilon" (2).
selected_interpolation_list_using_custom_formulation = [
    (alpha_bars, lambda alpha_bar, label=label: (label, (1 - (1 - alpha_bar)**2)**0.5, 1 - alpha_bar))
    for label, alpha_bars in (
        ("fixed", get_fixed_alpha_bars(1 - 1e-2/4)),
        ("linear", get_linear_alpha_bars(1 - 1e-5, 1 - 1e-2)),
        ("cosine", get_cosine_alpha_bars(1e-2)),
    )
]
In [628]:
plot_transformations(
    selected_interpolation_list,
    suptitle="Transformation of $x_t$ with selected $\\alpha$ scheduling (using sqrt combination)",
    include_custom_alpha_bar_list=True,
)
In [629]:
# This list uses "linear scaling on epsilon" (formulation 2), not the sqrt combination.
plot_transformations(selected_interpolation_list_using_custom_formulation, include_custom_alpha_bar_list=True, suptitle="Transformation of $x_t$ with selected $\\alpha$ scheduling (using linear scaling on $\\epsilon$)")
In [630]:
compare_alpha_bar_evolutions(selected_interpolation_list, suptitle="Evolution of $\\overline{\\alpha_t}$ with selected $\\alpha$ scheduling")
In [631]:
def forward(X_0, beta_scheduler):
    """Sample x_t ~ q(x_t | x_0) independently for every step t = 0, ..., len(beta_scheduler).

    Uses the closed form x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps with
    alpha_bar_t = prod_{s <= t} (1 - beta_s). A beta_0 = 0 is prepended so that
    index 0 returns x_0 itself.

    Args:
        X_0: clean data tensor.
        beta_scheduler: 1-D tensor of betas (beta_1, ..., beta_T).

    Returns:
        (X_s, alpha_bar_list):
            X_s: list of T + 1 tensors shaped like ``X_0``, each drawn with fresh noise.
            alpha_bar_list: 1-D tensor of the T + 1 alpha_bar values.
    """
    # new_zeros keeps beta_0 on the same dtype/device as the schedule.
    beta_scheduler = torch.cat([beta_scheduler.new_zeros(1), beta_scheduler])  # add beta_0
    n_steps = len(beta_scheduler)  # T + 1 entries; a distinct name avoids shadowing the global T
    alpha_bar_list = (1 - beta_scheduler).cumprod(dim=0)

    def qx_t(t):
        alpha_bar = alpha_bar_list[t]
        # randn_like draws standard-normal noise matching X_0's shape, dtype and device.
        return alpha_bar**0.5 * X_0 + (1 - alpha_bar)**0.5 * torch.randn_like(X_0)

    X_s = [qx_t(t) for t in range(n_steps)]
    return X_s, alpha_bar_list